# Train/evaluate MK-CAViT on ImageNet-style folders with ImageFolder.
import argparse
import os
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms as T
from MK_CAViT import mk_cavit_base
from train_utils import set_seed, train_one_epoch, evaluate_cls


def build_dataloaders(root: str, img_size: int, batch: int, workers: int):
    """Build ImageNet-style train/val loaders from ``root/train`` and ``root/val``.

    Both splits are read with ``torchvision.datasets.ImageFolder`` and
    normalized with the ImageNet channel mean/std.

    Args:
        root: dataset root containing ``train/`` and ``val/`` sub-folders.
        img_size: side length of the square crops fed to the model.
        batch: batch size used by both loaders.
        workers: number of worker processes for each loader.

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.
    """
    # ImageNet channel statistics, shared by both pipelines so they cannot drift apart.
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_tf = T.Compose([
        T.RandomResizedCrop(img_size, scale=(0.08, 1.0)),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        normalize,
    ])
    # NOTE(review): int(224 * 1.14) == 255, one pixel short of the common
    # resize-256 / crop-224 protocol; kept unchanged to preserve existing
    # evaluation behavior.
    val_tf = T.Compose([
        T.Resize(int(img_size * 1.14)),
        T.CenterCrop(img_size),
        T.ToTensor(),
        normalize,
    ])
    train_set = datasets.ImageFolder(os.path.join(root, 'train'), train_tf)
    val_set = datasets.ImageFolder(os.path.join(root, 'val'), val_tf)
    # Pinned host memory only speeds up host->GPU copies; without CUDA it is
    # useless and makes DataLoader emit a warning.
    pin = torch.cuda.is_available()
    train_loader = DataLoader(train_set, batch_size=batch, shuffle=True, num_workers=workers, pin_memory=pin)
    val_loader = DataLoader(val_set, batch_size=batch, shuffle=False, num_workers=workers, pin_memory=pin)
    return train_loader, val_loader


def main(args):
    """Train MK-CAViT on an ImageFolder dataset, evaluating every epoch.

    Prints train/val loss and top-1 after each epoch and saves the final
    ``state_dict`` to ``args.out``.

    Args:
        args: parsed command-line namespace; reads ``seed``, ``num_classes``,
            ``size``, ``root``, ``batch_size``, ``workers``, ``lr``, ``mu``,
            ``amp``, ``epochs`` and ``out``.
    """
    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Fail fast: create the checkpoint directory before training so a missing
    # directory cannot throw away a whole run at the final torch.save.
    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    model = mk_cavit_base(num_classes=args.num_classes, img_size=args.size).to(device)

    train_loader, val_loader = build_dataloaders(args.root, args.size, args.batch_size, args.workers)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.05)
    # GradScaler is CUDA-only; on CPU it would warn and disable itself anyway.
    scaler = torch.cuda.amp.GradScaler(enabled=args.amp and device.type == 'cuda')

    for epoch in range(args.epochs):
        # Set the module mode explicitly each epoch. It cannot be seen from
        # here whether train_one_epoch / evaluate_cls toggle it themselves,
        # and both calls are no-ops if they do.
        model.train()
        tr = train_one_epoch(model, train_loader, optimizer, device, task='cls', mu=args.mu, scaler=scaler)
        model.eval()
        ev = evaluate_cls(model, val_loader, device, task='cls')
        print(f"[{epoch+1:03d}/{args.epochs:03d}] "
              f"train loss {tr['loss']:.4f} | val loss {ev['loss']:.4f} | top1 {ev.get('top1', 0.):.2f}%")

    torch.save(model.state_dict(), args.out)


if __name__ == "__main__":
    # Command-line interface, declared as (flag, add_argument options) pairs
    # in the order they appear in --help.
    cli_spec = [
        ('--root', {'type': str, 'required': True, 'help': 'ImageNet root with train/ and val/ subfolders'}),
        ('--size', {'type': int, 'default': 224}),
        ('--num_classes', {'type': int, 'default': 1000}),
        ('--epochs', {'type': int, 'default': 90}),
        ('--batch_size', {'type': int, 'default': 64}),
        ('--workers', {'type': int, 'default': 8}),
        ('--lr', {'type': float, 'default': 3e-4}),
        ('--mu', {'type': float, 'default': 0.1, 'help': 'Weight of F-HGR term in the loss'}),
        ('--amp', {'action': 'store_true'}),
        ('--seed', {'type': int, 'default': 42}),
        ('--out', {'type': str, 'default': 'mk_cavit_imagenet.pth'}),
    ]
    parser = argparse.ArgumentParser()
    for flag, options in cli_spec:
        parser.add_argument(flag, **options)
    main(parser.parse_args())
